
from utils.DataLoader import DataLoader
import utils
import os
import torch
import torchvision
import random
from torchvision import datasets, transforms
import numpy as np
import copy
from torch.utils.data import ConcatDataset
from sklearn.model_selection import train_test_split
import os
import torch
from torchvision import datasets, transforms
from utils.DataLoader import DataLoader
from datasets import load_dataset, concatenate_datasets, DatasetDict, ClassLabel, load_from_disk, Dataset

# from torch.utils.data import DataLoader
from PIL import Image
import math

#
# class CustomImageFolder(datasets.ImageFolder):
#     def __init__(self, root, transform=None, num_classes=20):
#         super(CustomImageFolder, self).__init__(root, transform=transform)
#
#         # Get the names and indices of all classes
#         # classes, class_to_idx = self.find_classes(root)
#         # # Select the indices of the first num_classes classes
#         # selected_classes = classes[:num_classes]
#         # print(selected_classes)
#         selected_classes_idx = list(range(num_classes))
#
#         # selected_class_to_idx = {cls: i for i, cls in enumerate(selected_classes)}
#
#         # Keep only the samples belonging to the first num_classes classes
#         self.samples = [(path, cls) for path, cls in self.samples if cls in selected_classes_idx]
#         self.targets = [s[1] for s in self.samples]
#
#
# class SubsetDataset(torch.utils.data.Dataset):
#     def __init__(self, original_dataset, indices):
#         self.dataset = original_dataset
#         self.indices = indices
#         self.bd_data = []
#         self.bd_flag = False
#     def __len__(self):
#         return len(self.indices)
#
#     def __getitem__(self, index):
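#         # If the index falls inside bd_data and bd_data has been built, serve the sample from bd_data
#         if index < len(self.bd_data) and self.bd_flag:
#             return self.bd_data[index]
#         # Otherwise map to the corresponding index in the original dataset
#         original_index = self.indices[index]
#
#         # Fetch the sample from the original dataset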
#         data, target = self.dataset[original_index]
#         return data, target
#
#     def add_trigger(self, bd_maker, attack_portion=0.8):
#         self.bd_flag = True
#         for i in range(int(len(self) * attack_portion)):
#             data = self.dataset[self.indices[i]]
#             data = bd_maker.add_backdoor(data)
#             self.bd_data.append(data)



class DataLoader_domain(DataLoader):
    """Data pool built from six DomainNet domains, one pool entry per domain.

    Only labels ``< class_num`` are kept. Each domain is randomly capped at
    ``MAX_TRAIN_PER_DOMAIN`` train / ``MAX_TEST_PER_DOMAIN`` test samples.

    * ``split_num == 2``: labels are split at ``ceil(class_num * 0.7)``.
      Entry ``i`` receives domain ``i``'s samples with labels below the split,
      plus domain ``(i + 1) % pool_size``'s samples with labels at or above it.
    * ``split_num == 1``: entry ``i`` receives all of domain ``i``'s samples.

    The built loader is pickled with ``np.save`` to
    ``utils.pool_folder_path/<name>.npy``. It is reloaded on later runs
    unless ``recreate`` is true.
    """

    # Domain order fixes the pool index and the "next domain" wrap-around.
    DOMAINS = ('clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch')
    # Per-domain caps used to keep domains balanced.
    MAX_TRAIN_PER_DOMAIN = 3000
    MAX_TEST_PER_DOMAIN = 1000
    # Fraction of the label space that stays with a client's own domain.
    LOCAL_LABEL_FRACTION = 0.7
    # Worker processes used by `datasets` filter calls.
    NUM_PROC = 16

    def __init__(self,
                 batch_size=100,
                 split_num=2,
                 class_num=30,
                 input_require_shape=None,
                 pool_size=None,
                 params=None,
                 recreate=False,
                 *args,
                 **kwargs):
        """Load the cached pool or build it, then attach the image transform.

        Args:
            batch_size: Batch size; overridden by ``params['batch_size']``.
            split_num: 1 or 2 (see class docstring); overridden by ``params``.
            class_num: Number of leading labels kept; overridden by ``params``.
            input_require_shape: Forwarded to the base ``DataLoader``.
            pool_size: Forwarded to the base ``DataLoader``. When the pool is
                built it is replaced by the number of domains.
            params: Optional dict with ``batch_size``, ``split_num``, ``class_num``.
            recreate: Rebuild the pool even if a cached ``.npy`` exists.

        Raises:
            ValueError: If the pool must be built and ``split_num`` is not 1 or 2.
        """
        if params is not None:
            batch_size = params['batch_size']
            split_num = params['split_num']  # TODO: check split num
            class_num = params['class_num']
        name = (f'Domain_pool_{len(self.DOMAINS)}_split_{split_num}'
                f'_class_{class_num}_batchsize_{batch_size}')
        nickname = None
        super().__init__(name, nickname, pool_size, batch_size, input_require_shape)

        save_path = os.path.join(utils.pool_folder_path, f'{name}.npy')

        if os.path.exists(save_path) and not recreate:
            # NOTE(review): allow_pickle=True can execute arbitrary code from the
            # file; only load pool files this project wrote itself.
            cached = np.load(save_path, allow_pickle=True).item()  # load the pickled loader
            for attr, value in cached.__dict__.items():
                setattr(self, attr, value)
            print('Successfully Read the Data Pool.')
        else:
            self._build_domain_pool(name, split_num, class_num)
            # Saved before set_transform so the pickle holds raw (untransformed) data.
            np.save(save_path, self)

        trans = self._make_image_transform()
        for pool in self.data_pool:
            pool['local_training_data'].set_transform(trans)
            pool['local_test_data'].set_transform(trans)

    @staticmethod
    def _make_image_transform():
        """Return a batch transform that turns the 'image' column into tensors.

        Each image is converted to RGB, then to a tensor. It is resized to
        224x224 (bicubic) and normalized with fixed per-channel mean/std.
        Those values look like the usual CLIP statistics; confirm before reuse.
        """
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(
                (224, 224),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=False,
            ),
            transforms.Normalize(
                [0.48145466, 0.4578275, 0.40821073],
                [0.26862954, 0.26130258, 0.27577711]),
        ])

        def trans(examples):
            examples['image'] = [transform(image.convert("RGB")) for image in examples['image']]
            return examples

        return trans

    def _build_domain_pool(self, name, split_num, class_num):
        """Populate all loader attributes and ``self.data_pool`` from DomainNet."""
        # Fail fast, before any (slow) dataset loading or filtering.
        if split_num not in (1, 2):
            raise ValueError('Split num must be 2 or 1.')

        self.name = name
        self.domains = list(self.DOMAINS)
        self.pool_size = len(self.domains)
        self.input_data_shape = [3, 224, 224]
        self.target_class_num = class_num  # was params['class_num']: crashed when params is None
        self.total_training_number = 0
        self.total_test_number = 0
        self.server_data = {}
        self.server_data_number = {}
        self.statistic = {}

        dataset = self._load_domainnet_subset(class_num)
        train_dataset, test_dataset = dataset['train'], dataset['test']
        self.domain_dict = train_dataset.features['domain']
        self.label_dict = train_dataset.features['label']
        label_split = math.ceil(class_num * self.LOCAL_LABEL_FRACTION)

        local_datas, pool_datas = [], []
        for domain in self.domains:
            local_split, pool_split = self._split_domain(
                train_dataset, test_dataset, domain, split_num, label_split)
            local_datas.append(local_split)
            pool_datas.append(pool_split)

        self.data_pool = [self._make_pool_entry(pool_idx, local_datas, pool_datas)
                          for pool_idx in range(self.pool_size)]

    @classmethod
    def _load_domainnet_subset(cls, class_num):
        """Return a train/test DatasetDict restricted to labels ``< class_num``.

        The subset is read from ``Domain/class_num_<n>`` under
        ``utils.data_folder_path``. If it is missing, it is created from
        ``Domain/domain`` and saved there for later runs.
        """
        subset_path = os.path.join(utils.data_folder_path, 'Domain', f'class_num_{class_num}')
        try:
            return load_from_disk(subset_path)
        except FileNotFoundError:
            # `load_from_disk` raises FileNotFoundError both for a missing dir and
            # for a dir that holds no saved dataset; anything else should surface.
            print(f'DomainNet target class num {class_num} is not found, now begin to create')

        source_path = os.path.join(utils.data_folder_path, 'Domain', 'domain')

        def keep(label):
            return label < class_num

        dataset = DatasetDict({
            split: load_dataset(source_path, split=split).filter(
                keep, input_columns='label', num_proc=cls.NUM_PROC)
            for split in ('train', 'test')
        })
        dataset.save_to_disk(subset_path)
        return dataset

    def _split_domain(self, train_dataset, test_dataset, domain, split_num, label_split):
        """Return ``([local_train, local_test], [pool_train, pool_test])`` for one domain."""
        # Hoisted so the filter callbacks capture a plain int rather than `self`
        # (cheaper to pickle for num_proc workers and to fingerprint).
        domain_id = self.domain_dict.str2int(domain)

        def in_domain(value):
            return value == domain_id

        d_train_data = train_dataset.filter(in_domain, input_columns='domain', num_proc=self.NUM_PROC)
        d_test_data = test_dataset.filter(in_domain, input_columns='domain', num_proc=self.NUM_PROC)

        print(domain, 'train', len(d_train_data))
        print(domain, 'test', len(d_test_data))

        # Balance domains by randomly subsampling oversized ones.
        if len(d_train_data) > self.MAX_TRAIN_PER_DOMAIN:
            d_train_data = d_train_data.shuffle().select(range(self.MAX_TRAIN_PER_DOMAIN))
        if len(d_test_data) > self.MAX_TEST_PER_DOMAIN:
            d_test_data = d_test_data.shuffle().select(range(self.MAX_TEST_PER_DOMAIN))

        if split_num == 1:
            # Empty pool portion with exactly the same features as the real data,
            # so it always concatenates cleanly.
            return [d_train_data, d_test_data], [d_train_data.select([]), d_test_data.select([])]

        def is_local(label):
            return label < label_split

        def is_pool(label):
            return label >= label_split

        local_split = [d.filter(is_local, input_columns='label', num_proc=self.NUM_PROC)
                       for d in (d_train_data, d_test_data)]
        pool_split = [d.filter(is_pool, input_columns='label', num_proc=self.NUM_PROC)
                      for d in (d_train_data, d_test_data)]
        return local_split, pool_split

    def _make_pool_entry(self, pool_idx, local_datas, pool_datas):
        """Merge domain ``pool_idx``'s local split with the next domain's pool split."""
        neighbor_idx = (pool_idx + 1) % self.pool_size
        local_train, local_test = local_datas[pool_idx]
        neighbor_train, neighbor_test = pool_datas[neighbor_idx]

        train_data = concatenate_datasets([local_train, neighbor_train])
        test_data = concatenate_datasets([local_test, neighbor_test])

        # Sorted so the domain-name lists have a deterministic order.
        train_domains = self.domain_dict.int2str(sorted(set(train_data['domain'])))
        test_domains = self.domain_dict.int2str(sorted(set(test_data['domain'])))
        print(train_domains)
        print(test_domains)
        print(len(train_data))
        print(len(test_data))

        return {
            'local_training_data': train_data,
            'local_training_domain': train_domains,
            'local_test_data': test_data,
            'local_test_domain': test_domains,
            'local_training_number': len(train_data),
            'local_test_number': len(test_data),
            # domain_name: [train_num, test_num]
            'local_statistic': {
                self.domains[pool_idx]: [len(local_train), len(local_test)],
                self.domains[neighbor_idx]: [len(neighbor_train), len(neighbor_test)],
            },
        }

    def allocate(self, client_list):
        """Hand pool entry ``i`` to the ``i``-th client via ``client.update_data``.

        Raises:
            IndexError: If there are more clients than pool entries.
        """
        for pool_idx, client in enumerate(client_list):
            data_pool_item = self.data_pool[pool_idx]
            client.update_data(pool_idx,
                               data_pool_item['local_training_data'],
                               data_pool_item['local_training_number'],
                               data_pool_item['local_test_data'],
                               data_pool_item['local_test_number'],
                               data_pool_item['local_statistic'])
